{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear probing demo \n",
    "In this notebook, you can evaluate slide embeddings for TITAN using linear probing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# Thread caps for OpenMP-backed libraries must be in the environment before\n",
    "# those libraries are imported; setting them afterwards is typically ignored.\n",
    "import os\n",
    "os.environ[\"OMP_NUM_THREADS\"] = \"8\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import yaml\n",
    "\n",
    "from transformers import AutoModel\n",
    "from titan.eval_linear_probe import train_and_evaluate_logistic_regression_with_val\n",
    "from titan.utils import bootstrap\n",
    "\n",
    "# run on the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the TITAN model from the Hugging Face Hub and move it onto `device`\n",
    "# NOTE: trust_remote_code=True lets transformers run the custom model code shipped in the hub repository\n",
    "model = AutoModel.from_pretrained('MahmoodLab/TITAN', trust_remote_code=True).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the task configuration: the name of the label column and the dictionary\n",
    "# that maps that column's values to class labels\n",
    "# NOTE(review): yaml.FullLoader is meant for trusted files; consider yaml.safe_load if this config could be untrusted\n",
    "config_path = '../datasets/config_tcga-ot.yaml'\n",
    "with open(config_path, 'r') as config_file:\n",
    "    task_config = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "target, label_dict = task_config['target'], task_config['label_dict']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch the pre-extracted TITAN slide embeddings for TCGA from the Hugging Face Hub\n",
    "import pickle\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "slide_feature_path = hf_hub_download(\"MahmoodLab/TITAN\", filename=\"TCGA_TITAN_features.pkl\")\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load files from a source you trust\n",
    "with open(slide_feature_path, 'rb') as feature_file:\n",
    "    data = pickle.load(feature_file)\n",
    "# one row per entry of data['filenames'], paired with the matching item of data['embeddings']\n",
    "embeddings_df = pd.DataFrame({\n",
    "    'slide_id': data['filenames'],\n",
    "    'embeddings': list(data['embeddings'][:]),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load splits\n",
    "# read the train/val/test split tables\n",
    "train_split, val_split, test_split = (\n",
    "    pd.read_csv(split_csv)\n",
    "    for split_csv in (\n",
    "        '../datasets/tcga-ot_train.csv',\n",
    "        '../datasets/tcga-ot_val.csv',\n",
    "        '../datasets/tcga-ot_test.csv',\n",
    "    )\n",
    ")\n",
    "# attach each slide's embedding by joining on slide_id; pd.merge defaults to an\n",
    "# inner join, so only slides present in both tables are kept\n",
    "train_df, val_df, test_df = (\n",
    "    pd.merge(embeddings_df, split, on='slide_id')\n",
    "    for split in (train_split, val_split, test_split)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def features_and_labels(split_df):\n",
    "    \"\"\"Return (stacked embeddings, mapped labels) for one merged split table.\"\"\"\n",
    "    features = np.stack(split_df.embeddings.values)\n",
    "    # look each target value up in label_dict; an unknown value raises KeyError\n",
    "    labels = split_df[target].apply(lambda x: label_dict[x]).values\n",
    "    return features, labels\n",
    "\n",
    "\n",
    "train_data, train_labels = features_and_labels(train_df)\n",
    "val_data, val_labels = features_and_labels(val_df)\n",
    "test_data, test_labels = features_and_labels(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# candidate values for C: 10e-2 == 0.1 and 10e2 == 1000, so num=3 yields roughly 0.1, 10 and 1000\n",
    "log_spaced_values = np.logspace(np.log10(10e-2), np.log10(10e2), num=3)\n",
    "results, outputs = train_and_evaluate_logistic_regression_with_val(\n",
    "    train_data, train_labels,\n",
    "    val_data, val_labels,\n",
    "    test_data, test_labels,\n",
    "    log_spaced_values=log_spaced_values,\n",
    ")\n",
    "# to use the default setting from our paper use the default value for searching C (log_spaced_values = np.logspace(np.log10(10e-6), np.log10(10e5), num=45))\n",
    "# results, outputs = train_and_evaluate_logistic_regression_with_val(train_data, train_labels, val_data, val_labels, test_data, test_labels)\n",
    "for key, value in results.items():\n",
    "    print(f\"{key.split('/')[-1]: <12}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings forwarded to titan.utils.bootstrap (n is presumably the number of resamples - confirm against its definition)\n",
    "bootstrap_kwargs = dict(n=1000, alpha=0.95)\n",
    "# takes a while as 46 imbalanced classes are bootstrapped\n",
    "results_mean, results_std = bootstrap(results_dict=outputs, **bootstrap_kwargs)\n",
    "for keys, values in results_mean.items():\n",
    "    print(f\"{keys.split('/')[-1]: <12}: {values:.4f} ± {results_std[keys]:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "titan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
